{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook setup: auto-reload edited modules and render plots inline.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "\n",
    "# Order GPUs by PCI bus ID and expose only device 0 to CUDA.\n",
    "os.environ.update({\n",
    "    \"CUDA_DEVICE_ORDER\": \"PCI_BUS_ID\",\n",
    "    \"CUDA_VISIBLE_DEVICES\": \"0\",\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "using Keras version: 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "import ktrain\n",
    "from ktrain import text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the Data Into Arrays"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "size of training set: 2257\n",
      "size of validation set: 1502\n",
      "classes: ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']\n"
     ]
    }
   ],
   "source": [
    "from sklearn.datasets import fetch_20newsgroups\n",
    "\n",
    "# Restrict the 20 Newsgroups corpus to four topics.\n",
    "categories = [\n",
    "    'alt.atheism',\n",
    "    'soc.religion.christian',\n",
    "    'comp.graphics',\n",
    "    'sci.med',\n",
    "]\n",
    "\n",
    "# Fetch the 'train' and 'test' subsets with identical options.\n",
    "train_b = fetch_20newsgroups(\n",
    "    subset='train',\n",
    "    categories=categories,\n",
    "    shuffle=True,\n",
    "    random_state=42,\n",
    ")\n",
    "test_b = fetch_20newsgroups(\n",
    "    subset='test',\n",
    "    categories=categories,\n",
    "    shuffle=True,\n",
    "    random_state=42,\n",
    ")\n",
    "\n",
    "# Documents and labels for each split; the 'test' subset is used as validation data.\n",
    "x_train, y_train = train_b.data, train_b.target\n",
    "x_test, y_test = test_b.data, test_b.target\n",
    "\n",
    "print('size of training set: %s' % (len(x_train)))\n",
    "print('size of validation set: %s' % (len(x_test)))\n",
    "print('classes: %s' % (train_b.target_names))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 1: Preprocess Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "preprocessing train...\n",
      "language: en\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "preprocessing test...\n",
      "language: en\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Preprocess both splits in 'distilbert' mode with maxlen=350.\n",
    "# `preproc` is kept because the model-building and predictor cells below both take it.\n",
    "trn, val, preproc = text.texts_from_array(\n",
    "    x_train=x_train,\n",
    "    y_train=y_train,\n",
    "    x_test=x_test,\n",
    "    y_test=y_test,\n",
    "    class_names=train_b.target_names,\n",
    "    preprocess_mode='distilbert',\n",
    "    maxlen=350,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 2: Build a Model and Wrap in Learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fasttext: a fastText-like model [http://arxiv.org/pdf/1607.01759.pdf]\n",
      "logreg: logistic regression using a trainable Embedding layer\n",
      "nbsvm: NBSVM model [http://www.aclweb.org/anthology/P12-2018]\n",
      "bigru: Bidirectional GRU with pretrained word vectors [https://arxiv.org/abs/1712.09405]\n",
      "standard_gru: simple 2-layer GRU with randomly initialized embeddings\n",
      "bert: Bidirectional Encoder Representations from Transformers (BERT) [https://arxiv.org/abs/1810.04805]\n",
      "distilbert: distilled, smaller, and faster BERT from Hugging Face [https://arxiv.org/abs/1910.01108]\n"
     ]
    }
   ],
   "source": [
    "# Print the available text classifiers; 'distilbert' from this list is used below.\n",
    "text.print_text_classifiers()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is Multi-Label? False\n",
      "maxlen is 350\n",
      "done.\n"
     ]
    }
   ],
   "source": [
    "# Build the 'distilbert' classifier from the preprocessed training data.\n",
    "model = text.text_classifier(\n",
    "    'distilbert',\n",
    "    train_data=trn,\n",
    "    preproc=preproc,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the model together with its training and validation data in a ktrain learner.\n",
    "learner = ktrain.get_learner(\n",
    "    model,\n",
    "    train_data=trn,\n",
    "    val_data=val,\n",
    "    batch_size=6,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 3: Train Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "begin training using onecycle policy with max lr of 3e-05...\n",
      "Train for 377 steps, validate for 251 steps\n",
      "Epoch 1/4\n",
      "377/377 [==============================] - 69s 183ms/step - loss: 0.6874 - accuracy: 0.7882 - val_loss: 0.3342 - val_accuracy: 0.8948\n",
      "Epoch 2/4\n",
      "377/377 [==============================] - 58s 154ms/step - loss: 0.1288 - accuracy: 0.9659 - val_loss: 0.2021 - val_accuracy: 0.9341\n",
      "Epoch 3/4\n",
      "377/377 [==============================] - 60s 159ms/step - loss: 0.0538 - accuracy: 0.9849 - val_loss: 0.1343 - val_accuracy: 0.9621\n",
      "Epoch 4/4\n",
      "377/377 [==============================] - 61s 161ms/step - loss: 0.0176 - accuracy: 0.9973 - val_loss: 0.1415 - val_accuracy: 0.9634\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fbdc5e43e80>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train with the one-cycle policy: maximum learning rate 3e-5, 4 epochs.\n",
    "learner.fit_onecycle(3e-5, 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict on New Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair the trained model with its preprocessor so raw text can be passed to predict().\n",
    "p = ktrain.get_predictor(model, preproc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'comp.graphics'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Classify a previously unseen sentence; the result is the predicted class name.\n",
    "p.predict(\n",
    "    \"There is a problem with my computer monitor's resolution.  Everything is blurry.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
